{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "import tensorflow as tf\n",
    "import cv2\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "from architecture import shufflenet\n",
    "from input_pipeline import resize_keeping_aspect_ratio, central_crop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset locations default to the original machine's layout; set\n",
    "# IMAGENET_ROOT in the kernel's environment to point at another copy.\n",
    "IMAGENET_ROOT = os.environ.get('IMAGENET_ROOT', '/home/gpu2/hdd/dan/imagenet')\n",
    "# directory whose files are listed to collect the validation images\n",
    "VAL_IMAGES_PATH = os.path.join(IMAGENET_ROOT, 'val/')\n",
    "# devkit root; the ground-truth labels are read from its data/ subfolder\n",
    "DEVKIT_PATH = os.path.join(IMAGENET_ROOT, 'ILSVRC2012_devkit_t12/')\n",
    "# checkpoint prefix passed to `saver.restore`\n",
    "CHECKPOINT_PATH = 'run01/model.ckpt-1661328'\n",
    "# if True, the saver is built from `ExponentialMovingAverage.variables_to_restore()`\n",
    "USE_MOVING_AVERAGE = True\n",
    "# forwarded as `depth_multiplier` to `shufflenet`; note that it is a string\n",
    "DEPTH_MULTIPLIER = '1.0'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get labels for validation images\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ground_truth_file = os.path.join(DEVKIT_PATH, 'data/ILSVRC2012_validation_ground_truth.txt')\n",
    "\n",
    "# one integer label per line; subtract 1 because\n",
    "# the network predicts labels from 0 to 999\n",
    "with open(ground_truth_file) as f:\n",
    "    content = [int(line.strip()) - 1 for line in f]\n",
    "\n",
    "len(content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_index(filename):\n",
    "    \"\"\"Parse the integer between the last '_' and the next '.' in `filename`.\"\"\"\n",
    "    return int(filename.rsplit('_', 1)[-1].split('.', 1)[0])\n",
    "\n",
    "\n",
    "# pair every file in the validation folder with its label;\n",
    "# images are numbered from 1 to 50000, hence the -1 when indexing `content`\n",
    "metadata = [\n",
    "    (os.path.join(VAL_IMAGES_PATH, filename), content[image_index(filename) - 1])\n",
    "    for filename in os.listdir(VAL_IMAGES_PATH)\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build the inference graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "# one RGB image of arbitrary height and width is fed per `sess.run` call\n",
    "raw_image = tf.placeholder(tf.uint8, [None, None, 3])\n",
    "\n",
    "# preprocessing: scale pixel values to [0, 1], resize,\n",
    "# take a central crop and add a leading batch axis\n",
    "MIN_DIMENSION = 256\n",
    "IMAGE_SIZE = 224\n",
    "scaled = (1.0 / 255.0) * tf.cast(raw_image, tf.float32)\n",
    "resized = resize_keeping_aspect_ratio(scaled, MIN_DIMENSION)\n",
    "cropped = central_crop(resized, crop_height=IMAGE_SIZE, crop_width=IMAGE_SIZE)\n",
    "image = tf.expand_dims(cropped, 0)\n",
    "\n",
    "logits = shufflenet(image, is_training=False, depth_multiplier=DEPTH_MULTIPLIER)\n",
    "# the batch holds a single image, so keep only its row of probabilities\n",
    "probabilities = tf.nn.softmax(logits, axis=1)[0]\n",
    "\n",
    "# choose which checkpoint tensors are loaded into the model variables\n",
    "if not USE_MOVING_AVERAGE:\n",
    "    saver = tf.train.Saver()\n",
    "else:\n",
    "    ema = tf.train.ExponentialMovingAverage(decay=0.995)\n",
    "    variables_to_restore = ema.variables_to_restore()\n",
    "    saver = tf.train.Saver(variables_to_restore)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one (top-1 correct, top-5 correct) pair is collected per evaluated image\n",
    "results = []\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess, CHECKPOINT_PATH)\n",
    "    for path, label in tqdm(metadata):\n",
    "\n",
    "        # cv2.imread reports failure by returning None instead of raising,\n",
    "        # so stop here with the offending path rather than on `.shape` below\n",
    "        image_array = cv2.imread(path)\n",
    "        if image_array is None:\n",
    "            raise IOError('could not read image: {}'.format(path))\n",
    "        assert image_array.shape[2] == 3\n",
    "        # OpenCV returns channels in BGR order; reorder them to RGB before feeding\n",
    "        image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "        result = sess.run(probabilities, {raw_image: image_array})\n",
    "        # argsort is ascending: the last five indices are the top-5 classes\n",
    "        # and the very last one is the top-1 prediction\n",
    "        top_predicted_labels = np.argsort(result)[-5:]\n",
    "\n",
    "        results.append((\n",
    "            top_predicted_labels[-1] == label,\n",
    "            (top_predicted_labels == label).any()\n",
    "        ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# normalise by the number of images actually evaluated rather than a\n",
    "# hard-coded 50000, so the figures stay correct for a partial validation set\n",
    "num_evaluated = len(results)\n",
    "assert num_evaluated > 0, 'no images were evaluated'\n",
    "\n",
    "accuracy = sum(t[0] for t in results) / num_evaluated\n",
    "top5_accuracy = sum(t[1] for t in results) / num_evaluated\n",
    "\n",
    "print(accuracy, top5_accuracy)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
